import torch
import math
import random
import numpy as np
from ..utils import *
from ..gradient.mifgsm import MIFGSM
import torchvision.transforms.functional as TF

class CWT(MIFGSM):
    """Block-wise input-transformation attack layered on top of MIFGSM.

    Each call to :meth:`transform` produces ``num_scale`` transformed copies of the
    input. For every copy:

    1. The image is split into a ``num_block x num_block`` grid of randomly sized blocks.
    2. Each block is rescaled by a random factor in ``[1.0, range_max]``.
    3. Some blocks are rotated by up to ``max_angle`` degrees.
    4. Every block is cropped back to its original size.

    The copies are concatenated along dim 0, and :meth:`get_loss` repeats the labels
    to match.

    Arguments:
        model_name, epsilon, alpha, epoch, decay: forwarded positionally to MIFGSM.
        num_scale: number of transformed copies produced per ``transform`` call.
        num_block: number of rows and of columns in the block grid.
        rotation_probability: stored on the instance but not consulted by this class.
        max_angle: rotated blocks use an angle drawn uniformly from
            ``[-max_angle, max_angle]``.
        range_max: upper bound of the per-block scale factors (the lower bound is 1.0).
        targeted, random_start, norm, loss, device, attack: forwarded positionally
            to MIFGSM.
        **kwargs: accepted and ignored.
    """

    def __init__(self, model_name, epsilon=16/255, alpha=1.6/255, epoch=10, decay=1., num_scale=20, num_block=2,
                 rotation_probability=0.5, max_angle=20, range_max=1.3, targeted=False, random_start=False,
                 norm='linfty', loss='crossentropy', device=None, attack='CWT', **kwargs):
        super().__init__(model_name, epsilon, alpha, epoch, decay, targeted, random_start, norm, loss, device, attack)
        self.num_scale = num_scale
        self.num_block = num_block
        # Kept for reference; transform() decides rotation via idx_to_rotate instead.
        self.rotation_probability = rotation_probability
        self.max_angle = max_angle
        self.range_max = range_max

    def get_length(self, length, num_blocks):
        """Randomly split ``length`` pixels into ``num_blocks`` integer segments.

        Random weights distribute ``length - 5 * num_blocks`` pixels proportionally.
        Any segment shorter than 10 is raised to 10. The last segment then absorbs the
        difference so that the segments sum to exactly ``length``.

        NOTE(review): for very small ``length`` the final adjustment can leave the last
        segment below 10, or even non-positive. ``split_image_into_blocks`` maps
        non-positive segments to ``None`` blocks.

        Returns:
            Tuple of ``num_blocks`` NumPy int32 values whose sum is ``length``.
        """
        rand = np.random.uniform(size=num_blocks)
        rand_norm = np.round(rand / rand.sum() * (length - num_blocks * 5)).astype(np.int32)
        # Enforce a minimum segment size of 10 pixels.
        rand_norm[rand_norm < 10] = 10
        # The last segment absorbs rounding/clamping drift so the total equals `length`.
        rand_norm[-1] += length - rand_norm.sum()
        return tuple(rand_norm)

    def split_image_into_blocks(self, img_tensor, num_blocks=(2, 2)):
        """Slice a 4-D tensor into a grid of randomly sized spatial blocks.

        Args:
            img_tensor: 4-D tensor. Dims 2 and 3 are the ones split.
            num_blocks: ``(rows, cols)`` of the grid.

        Returns:
            ``(blocks, row_lengths, col_lengths)``.
            - ``blocks`` is a row-major list of slices of ``img_tensor``. Cells whose
              row or column length is not positive hold ``None``.
            - ``row_lengths`` and ``col_lengths`` are the segment sizes from
              :meth:`get_length`.
        """
        _, _, H, W = img_tensor.shape
        row_lengths = self.get_length(H, num_blocks[0])
        col_lengths = self.get_length(W, num_blocks[1])

        blocks = []
        start_h = 0
        for row in row_lengths:
            start_w = 0
            for col in col_lengths:
                if row > 0 and col > 0:
                    block = img_tensor[:, :, start_h:start_h + row, start_w:start_w + col]
                else:
                    block = None
                blocks.append(block)
                start_w += col
            start_h += row

        return blocks, row_lengths, col_lengths

    def random_crop_edges(self, blocks, H, W):
        """Crop a block tensor back to spatial size ``(H, W)``.

        Between one and four of the edges top/bottom/left/right are selected at random.

        Rows:
        - If 'bottom' is selected, the window spans ``H`` rows. It starts at a random
          offset when 'top' is also selected, and at row 0 otherwise.
        - If 'bottom' is not selected, the window is anchored to the bottom edge, and
          any offset drawn for 'top' is overwritten.

        Columns follow the same rules with 'right' and 'left'.

        If the slice is not exactly ``(H, W)`` (for example, when the input is smaller
        than the target), it is bilinearly resized to ``(H, W)``.
        """
        _, _, scaled_H, scaled_W = blocks.shape
        top, left = 0, 0
        bottom, right = scaled_H, scaled_W

        num_edges = random.randint(1, 4)
        edges = ['top', 'bottom', 'left', 'right']
        selected_edges = random.sample(edges, k=num_edges)

        if 'top' in selected_edges:
            top = random.randint(0, max(0, scaled_H - H))
        if 'bottom' in selected_edges:
            bottom = top + H
        else:
            top = max(0, bottom - H)

        if 'left' in selected_edges:
            left = random.randint(0, max(0, scaled_W - W))
        if 'right' in selected_edges:
            right = left + W
        else:
            left = max(0, right - W)

        cropped_blocks = blocks[:, :, top:bottom, left:right]

        if cropped_blocks.shape[2:] != (H, W):
            cropped_blocks = torch.nn.functional.interpolate(cropped_blocks, size=(H, W), mode='bilinear', align_corners=False)

        return cropped_blocks

    def generate_scale_factors(self, num_blocks, range_min=1.0, range_max=1.3, num_choices=41):
        """Draw one scale factor per block from an evenly spaced grid.

        The grid holds ``num_choices`` values from ``range_min`` to ``range_max``,
        each rounded to 2 decimals.
        - When ``num_blocks <= num_choices``, the factors are distinct (shuffle and
          take a prefix).
        - Otherwise, the whole shuffled grid is used and then padded with draws with
          replacement. This lets grids of more than ``num_choices`` blocks (e.g.
          ``num_block >= 7`` with the default 41 choices) still get one factor each.

        Raises:
            ValueError: if ``num_choices < 1``.
        """
        if num_choices < 1:
            raise ValueError("num_choices must be at least 1")
        if num_choices == 1:
            # A single choice would otherwise divide by zero in the grid step below.
            scale_choices = [round(range_min, 2)]
        else:
            scale_choices = [round(range_min + i * (range_max - range_min) / (num_choices - 1), 2) for i in range(num_choices)]
        random.shuffle(scale_choices)  # shuffle randomly
        if num_blocks <= len(scale_choices):
            return scale_choices[:num_blocks]
        return scale_choices + random.choices(scale_choices, k=num_blocks - len(scale_choices))

    def transform_block(self, block_batch, scale_factor, i, idx_to_rotate, angle_range=(0, 360), max_angle=20):
        """Rescale, optionally rotate, and crop one block back to its original size.

        Let the block have spatial size ``(H, W)`` and let ``s = scale_factor``.
        1. Bilinearly resize down to ``(max(int(H / s), 2), max(int(W / s), 2))``.
        2. Bilinearly resize up to ``(int(H * s), int(W * s))``.
        3. If ``i`` is in ``idx_to_rotate``, rotate by an angle drawn uniformly from
           ``[-max_angle, max_angle]``, with zero fill and no canvas expansion.
        4. Crop back to ``(H, W)`` with :meth:`random_crop_edges`.

        Args:
            block_batch: 4-D block tensor, or ``None``.
            scale_factor: the per-block factor ``s``.
            i: index of this block in row-major grid order.
            idx_to_rotate: indices of the blocks to rotate.
            angle_range: accepted for interface compatibility; unused.
            max_angle: bound for the random rotation angle.

        Returns:
            A tensor with the same shape as ``block_batch``. ``block_batch`` itself is
            returned unchanged when it is ``None`` or smaller than 2 pixels in either
            spatial dimension.
        """
        if block_batch is None or block_batch.shape[2] < 2 or block_batch.shape[3] < 2:
            return block_batch

        _, _, H, W = block_batch.shape

        # Down-sample first (losing detail), then up-sample past the original size.
        down_H = max(int(H / scale_factor), 2)
        down_W = max(int(W / scale_factor), 2)
        resized_blocks = torch.nn.functional.interpolate(
            block_batch, size=(down_H, down_W), mode='bilinear', align_corners=False
        )

        scaled_H = int(H * scale_factor)
        scaled_W = int(W * scale_factor)
        scaled_blocks = torch.nn.functional.interpolate(
            resized_blocks, size=(scaled_H, scaled_W), mode='bilinear', align_corners=False
        )

        if i in idx_to_rotate:
            angle = random.uniform(-max_angle, max_angle)
            scaled_blocks = TF.rotate(
                scaled_blocks,
                angle=angle,
                expand=False,
                center=(scaled_W / 2, scaled_H / 2),
                fill=0
            )

        return self.random_crop_edges(scaled_blocks, H, W)

    def transform(self, x, **kwargs):
        """Produce ``num_scale`` block-transformed copies of ``x``, stacked on dim 0.

        For each copy:
        1. Split ``x`` into a ``num_block x num_block`` grid.
        2. Draw per-block scale factors in ``[1.0, range_max]``.
        3. Pick the blocks to rotate.
        4. Transform each block and write it back into a zero-initialised tensor
           shaped like ``x``. Cells without a block stay zero.

        NOTE(review): ``idx_to_rotate`` is hard-coded to ``[0, 1]`` or ``[2, 3]``. For
        the default 2x2 grid this is the top or bottom row; for other grid sizes it
        does not correspond to a row. Confirm this is intended.

        Returns:
            ``torch.cat`` along dim 0 of ``num_scale`` tensors shaped like ``x``.
        """
        transformed_images = []

        for _ in range(self.num_scale):
            blocks, row_lengths, col_lengths = self.split_image_into_blocks(x, num_blocks=(self.num_block, self.num_block))

            scale_factors = self.generate_scale_factors(len(blocks), range_max=self.range_max)

            if len(blocks) >= 2:
                idx_to_rotate = random.choice([[0, 1], [2, 3]])
            else:
                idx_to_rotate = [0]

            transformed_blocks = []
            for i, block in enumerate(blocks):
                if block is not None:
                    transformed_blocks.append(
                        self.transform_block(block, scale_factors[i], i, idx_to_rotate, max_angle=self.max_angle)
                    )
                else:
                    transformed_blocks.append(None)

            # Reassemble the grid in the same row-major order used to split it.
            transformed_image = torch.zeros_like(x)
            start_h, idx = 0, 0
            for row_len in row_lengths:
                start_w = 0
                for col_len in col_lengths:
                    if transformed_blocks[idx] is not None:
                        transformed_image[:, :, start_h:start_h + row_len, start_w:start_w + col_len] = transformed_blocks[idx]
                    start_w += col_len
                    idx += 1
                start_h += row_len

            transformed_images.append(transformed_image)

        return torch.cat(transformed_images, dim=0)

    def get_loss(self, logits, label):
        """Calculate the loss over all ``num_scale`` stacked copies.

        ``label`` is repeated ``num_scale`` times to line up with the dim-0 ordering
        produced by :meth:`transform`. The loss is negated when the attack is targeted.
        """
        return -self.loss(logits, label.repeat(self.num_scale)) if self.targeted else self.loss(logits, label.repeat(self.num_scale))

